{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f29588be-fbc2-4101-8734-661a716ea1c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "from mindspore.common.initializer import initializer\n",
    "\n",
    "import mindnlp\n",
    "from mindnlp import Vocab\n",
    "from mindspore.dataset.transforms import PadEnd\n",
    "from mindnlp.transforms import BasicTokenizer, PadTransform, Lookup\n",
    "from mindnlp.modules import Glove, StaticLSTM\n",
    "from mindnlp.metrics import accuracy_fn\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "daab0038-dc5d-4b0f-ab3b-5e8793c7f5e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the module from /home/daiyuxin/.cache/huggingface/modules/datasets_modules/datasets/imdb/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0 (last modified on Sun Oct  8 16:06:39 2023) since it couldn't be found locally at imdb., or remotely on the Hugging Face Hub.\n",
      "Found cached dataset imdb (/data1/lvyufeng/mindnlp/examples/classification/.mindnlp/datasets/imdb/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 419.56it/s]\n"
     ]
    }
   ],
   "source": [
    "# Seed MindSpore's global RNG and the dataset pipeline so that the shuffle\n",
    "# below and the later random train/valid split are repeatable across runs.\n",
    "seed = 42\n",
    "mindspore.set_seed(seed)\n",
    "mindspore.dataset.config.set_seed(seed)\n",
    "\n",
    "# Load the IMDB 'train' and 'test' splits (shuffled) through mindnlp.\n",
    "imdb_ds = mindnlp.load_dataset('imdb', split=['train', 'test'], shuffle=True)\n",
    "imdb_train, imdb_test = imdb_ds['train'], imdb_ds['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c6caea8f-e16f-4df9-a275-d5cb51c4b817",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer applied to the raw review text (see the map over 'text' below).\n",
    "# lower_case=True presumably lower-cases tokens before lookup - confirm in mindnlp docs.\n",
    "tokenizer = BasicTokenizer(lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "35455298-838d-4e6d-9317-e0fd39e56241",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary taken from the pretrained \"glove.6B.100d\" entry, plus a lookup op\n",
    "# over that vocabulary which is given '<unk>' as its unknown token.\n",
    "vocab = Vocab.from_pretrained(name=\"glove.6B.100d\")\n",
    "lookup_op = Lookup(vocab, unk_token='<unk>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4719b5b9-b413-47b1-b0df-9b357a5d6b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequences are padded to max_length using the id the vocab returns for '<pad>'.\n",
    "# NOTE(review): presumably longer sequences are also truncated to max_length - confirm.\n",
    "max_length = 256\n",
    "pad_op = PadTransform(max_length, pad_value=vocab('<pad>'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9bae9d35-ba6b-40d7-9b7d-6d91ee2c5f46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply tokenizer -> lookup_op -> pad_op, in that order, to the 'text' column\n",
    "# of both the train and the test split.\n",
    "imdb_train = imdb_train.map([tokenizer, lookup_op, pad_op], 'text')\n",
    "imdb_test = imdb_test.map([tokenizer, lookup_op, pad_op], 'text')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "74e2d623-3c78-4852-8dde-9429853ad4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# Group both splits into batches of batch_size samples.\n",
    "imdb_train = imdb_train.batch(batch_size)\n",
    "imdb_test = imdb_test.batch(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "991d9198-51fb-4f75-85fc-296278d916ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(108628:139630022588224,MainProcess):2023-10-08-18:01:44.148.917 [mindspore/dataset/engine/datasets.py:1200] Dataset is shuffled before split.\n"
     ]
    }
   ],
   "source": [
    "# Hold out 30% of the training data for validation. The split runs on the\n",
    "# already-batched dataset, so it divides batches rather than single samples.\n",
    "imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4f07b456-0277-4387-b438-7f7b1a083b5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore.common.initializer import Uniform, HeUniform\n",
    "\n",
    "class RNN(nn.Cell):\n",
    "    \"\"\"LSTM sequence classifier: embedding -> StaticLSTM -> dense layer.\n",
    "\n",
    "    Args:\n",
    "        embedding: embedding cell applied to the input ids; its `_embed_dim`\n",
    "            attribute is used as the LSTM input size.\n",
    "        hidden_dim: hidden size of the LSTM.\n",
    "        output_dim: output size of the final dense layer.\n",
    "        n_layers: number of stacked LSTM layers.\n",
    "        bidirectional: whether the LSTM runs in both directions.\n",
    "        dropout: dropout value forwarded to StaticLSTM. Defaults to 0.5,\n",
    "            the value that was previously hard-coded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embedding, hidden_dim, output_dim, n_layers,\n",
    "                 bidirectional, dropout=0.5):\n",
    "        super().__init__()\n",
    "        embedding_dim = embedding._embed_dim\n",
    "        self.embedding = embedding\n",
    "        # Remembered so construct() can pick the matching final hidden state(s).\n",
    "        self.bidirectional = bidirectional\n",
    "        self.rnn = StaticLSTM(embedding_dim,\n",
    "                              hidden_dim,\n",
    "                              num_layers=n_layers,\n",
    "                              bidirectional=bidirectional,\n",
    "                              batch_first=True,\n",
    "                              dropout=dropout)\n",
    "        # The classifier sees one hidden vector per direction, so its input\n",
    "        # width must follow the flag instead of always being hidden_dim * 2.\n",
    "        num_directions = 2 if bidirectional else 1\n",
    "        self.fc = nn.Dense(hidden_dim * num_directions, output_dim)\n",
    "\n",
    "    def construct(self, inputs):\n",
    "        \"\"\"Return the dense-layer output for a batch of token ids.\"\"\"\n",
    "        embedded = self.embedding(inputs)\n",
    "        _, (hidden, _) = self.rnn(embedded)\n",
    "        if self.bidirectional:\n",
    "            # Assumes the usual LSTM layout in which the last two entries of\n",
    "            # `hidden` are the top layer's forward and backward states.\n",
    "            hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)\n",
    "        else:\n",
    "            # Single direction: only the last entry belongs to the top layer.\n",
    "            hidden = hidden[-1, :, :]\n",
    "        output = self.fc(hidden)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "206e4277-af8e-430b-9533-1a612daf0a80",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 822M/822M [02:38<00:00, 5.43MB/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained GloVe embedding ('6B', 100-d); the vocab itself is built\n",
    "# separately via Vocab.from_pretrained in an earlier cell.\n",
    "# NOTE(review): the special-token order given here should agree with the ids the\n",
    "# vocab assigns to '<unk>' and '<pad>' - confirm.\n",
    "embedding = Glove.from_pretrained('6B', 100, special_tokens=[\"<unk>\", \"<pad>\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "551c0b0b-d116-4c72-825e-7d022506a2ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimisation hyper-parameters.\n",
    "hidden_size = 256\n",
    "output_size = 2\n",
    "num_layers = 2\n",
    "bidirectional = True\n",
    "lr = 5e-4\n",
    "\n",
    "# Build the classifier on top of the pretrained embedding, with a cross-entropy\n",
    "# loss and an Adam optimizer over the model's trainable parameters.\n",
    "model = RNN(embedding, hidden_size, output_size, num_layers, bidirectional)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bea9aa5c-2dc0-455c-b780-1c011fa4d811",
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_weights(m):\n",
    "    \"\"\"Re-initialise the parameters of a Dense or StaticLSTM cell in place.\n",
    "\n",
    "    Intended to be passed to `model.apply`. Cells of any other type are left\n",
    "    untouched.\n",
    "\n",
    "    * nn.Dense: 'xavier_normal' weight and, when a bias exists, 'zeros' bias.\n",
    "    * StaticLSTM: parameters whose name contains 'bias' get 'zeros';\n",
    "      otherwise parameters whose name contains 'weight' get 'orthogonal'.\n",
    "    \"\"\"\n",
    "    if isinstance(m, nn.Dense):\n",
    "        m.weight.set_data(initializer('xavier_normal', m.weight.shape, m.weight.dtype))\n",
    "        # A Dense created with has_bias=False has no bias parameter to reset.\n",
    "        bias = getattr(m, 'bias', None)\n",
    "        if bias is not None:\n",
    "            bias.set_data(initializer('zeros', bias.shape, bias.dtype))\n",
    "    elif isinstance(m, StaticLSTM):\n",
    "        for name, param in m.parameters_and_names():\n",
    "            if 'bias' in name:\n",
    "                param.set_data(initializer('zeros', param.shape, param.dtype))\n",
    "            elif 'weight' in name:\n",
    "                param.set_data(initializer('orthogonal', param.shape, param.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "af269446-61cf-4093-9c44-11277f0ba6cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RNN<\n",
       "  (embedding): Glove<\n",
       "    (dropout_layer): Dropout<>\n",
       "    >\n",
       "  (rnn): StaticLSTM<\n",
       "    (rnn): MultiLayerRNN<\n",
       "      (cell_list): CellList<\n",
       "        (0): SingleLSTMLayer_GPU<>\n",
       "        (1): SingleLSTMLayer_GPU<>\n",
       "        >\n",
       "      (dropout): Dropout<p=0.5>\n",
       "      >\n",
       "    >\n",
       "  (fc): Dense<input_channels=512, output_channels=2, has_bias=True>\n",
       "  >"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run initialize_weights over the model's cells. As the last expression of the\n",
    "# cell, the value returned by apply() is displayed (the model repr).\n",
    "model.apply(initialize_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "66750264-9d22-4f47-aec1-8e49fb6db0ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_fn(data, label):\n",
    "    \"\"\"Return the loss of the notebook-level `model` on one batch.\"\"\"\n",
    "    return loss_fn(model(data), label)\n",
    "\n",
    "# grad_fn(data, label) yields (loss, grads) for forward_fn; gradients are taken\n",
    "# with respect to the optimizer's parameters (None is passed as grad_position).\n",
    "grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)\n",
    "\n",
    "def train_step(data, label):\n",
    "    \"\"\"Run one optimisation step on a batch and return that batch's loss.\"\"\"\n",
    "    batch_loss, gradients = grad_fn(data, label)\n",
    "    # Calling the optimizer with the gradients applies the parameter update.\n",
    "    optimizer(gradients)\n",
    "    return batch_loss\n",
    "\n",
    "def train_one_epoch(model, train_dataset, epoch=0):\n",
    "    \"\"\"Train `model` for one pass over `train_dataset` with a tqdm progress bar.\n",
    "\n",
    "    The bar is labelled with `epoch` and shows the running mean of the batch\n",
    "    losses. Nothing is returned.\n",
    "    \"\"\"\n",
    "    model.set_train()\n",
    "    num_batches = train_dataset.get_dataset_size()\n",
    "    running_loss = 0\n",
    "    with tqdm(total=num_batches) as progress:\n",
    "        progress.set_description('Epoch %i' % epoch)\n",
    "        batches = train_dataset.create_tuple_iterator()\n",
    "        # `step` counts from 1 so it is the divisor of the running mean.\n",
    "        for step, (data, label) in enumerate(batches, start=1):\n",
    "            step_loss = train_step(data, label.astype(mindspore.int32))\n",
    "            running_loss += step_loss.asnumpy()\n",
    "            progress.set_postfix(loss=running_loss / step)\n",
    "            progress.update(1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3ab6c926-b5c9-4ddc-84cc-4282d24a6851",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_dataset, criterion, epoch=0):\n",
    "    \"\"\"Evaluate `model` on `test_dataset` and return the mean batch loss.\n",
    "\n",
    "    A tqdm bar labelled with `epoch` shows the running mean loss and the\n",
    "    running mean of `accuracy_fn` over the batches seen so far. The returned\n",
    "    value is the summed batch loss divided by the dataset size.\n",
    "    \"\"\"\n",
    "    num_batches = test_dataset.get_dataset_size()\n",
    "    loss_sum = 0\n",
    "    acc_sum = 0\n",
    "    model.set_train(False)\n",
    "\n",
    "    with tqdm(total=num_batches) as progress:\n",
    "        progress.set_description('Epoch %i' % epoch)\n",
    "        batches = test_dataset.create_tuple_iterator()\n",
    "        for step, batch in enumerate(batches, start=1):\n",
    "            data, label = batch[0], batch[1]\n",
    "            predictions = model(data)\n",
    "\n",
    "            # The criterion receives int32 labels; accuracy_fn gets them as-is.\n",
    "            batch_loss = criterion(predictions, label.astype(mindspore.int32))\n",
    "            loss_sum += batch_loss.asnumpy()\n",
    "            acc_sum += accuracy_fn(predictions, label)\n",
    "\n",
    "            progress.set_postfix(loss=loss_sum/step, acc=acc_sum/step)\n",
    "            progress.update(1)\n",
    "\n",
    "    return loss_sum / num_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "63810da0-d866-4a9f-896f-d50441996831",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████| 274/274 [00:18<00:00, 14.58it/s, loss=0.654]\n",
      "Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████| 117/117 [00:06<00:00, 18.92it/s, acc=0.692, loss=0.586]\n",
      "Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 274/274 [00:14<00:00, 18.60it/s, loss=0.641]\n",
      "Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 117/117 [00:07<00:00, 15.82it/s, acc=0.696, loss=0.607]\n",
      "Epoch 2:   0%|                                                                                                          | 0/274 [00:00<?, ?it/s][WARNING] PYNATIVE(108628,7efe25e57740,python):2023-10-08-18:05:43.310.834 [mindspore/ccsrc/pipeline/pynative/grad/grad.cc:1199] CheckAlreadyRun] The input info T184309_T184311_ is not the same with pre input info T184187_, forward process will run again\n",
      "Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████████| 274/274 [00:25<00:00, 10.96it/s, loss=0.58]\n",
      "Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 117/117 [00:06<00:00, 19.08it/s, acc=0.628, loss=0.641]\n",
      "Epoch 3:   0%|                                                                                                          | 0/274 [00:00<?, ?it/s][WARNING] PYNATIVE(108628,7efe25e57740,python):2023-10-08-18:06:12.844.055 [mindspore/ccsrc/pipeline/pynative/grad/grad.cc:1199] CheckAlreadyRun] The input info T275884_T275886_ is not the same with pre input info T275762_, forward process will run again\n",
      "Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████| 274/274 [00:14<00:00, 18.76it/s, loss=0.541]\n",
      "Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████| 117/117 [00:06<00:00, 18.64it/s, acc=0.697, loss=0.573]\n",
      "Epoch 4:   0%|                                                                                                          | 0/274 [00:00<?, ?it/s][WARNING] PYNATIVE(108628,7efe25e57740,python):2023-10-08-18:06:34.315.342 [mindspore/ccsrc/pipeline/pynative/grad/grad.cc:1199] CheckAlreadyRun] The input info T367992_T367994_ is not the same with pre input info T367851_, forward process will run again\n",
      "Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████| 274/274 [00:14<00:00, 18.39it/s, loss=0.45]\n",
      "Epoch 4: 100%|██████████████████████████████████████████████████████████████████████████| 117/117 [00:06<00:00, 19.08it/s, acc=0.85, loss=0.353]\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 5\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "# Train, validate, and checkpoint whenever the validation loss improves.\n",
    "for epoch in range(num_epochs):\n",
    "    train_one_epoch(model, imdb_train, epoch)\n",
    "    valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)\n",
    "\n",
    "    # No improvement over the best loss so far: keep the previous checkpoint.\n",
    "    if not valid_loss < best_valid_loss:\n",
    "        continue\n",
    "    best_valid_loss = valid_loss\n",
    "    ms.save_checkpoint(model, './sentiment_analysis.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b40939a9-5d26-4c84-b1fb-38d98eab46a4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
